/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { ActionMetadata, modelActionMetadata, z } from 'genkit';
import {
  GenerationCommonConfigSchema,
  ModelAction,
  ModelInfo,
  ModelReference,
  modelRef,
} from 'genkit/model';
import { model as pluginModel } from 'genkit/plugin';
import { imagenPredict } from './client.js';
import { fromImagenResponse, toImagenPredictRequest } from './converters.js';
import { ClientOptions, Model, VertexPluginOptions } from './types.js';
import {
  calculateRequestOptions,
  checkModelName,
  extractVersion,
  modelName,
} from './utils.js';

/**
 * Zod schema for the `config` of Imagen model requests.
 *
 * Extends `GenerationCommonConfigSchema` with Imagen-specific options. Every
 * top-level field declared here is optional. The trailing `.passthrough()`
 * keeps keys that are not declared in this schema instead of stripping them
 * during parsing.
 *
 * See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api.
 */
export const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
  // TODO: Remove common config schema extension since Imagen models don't support
  // most of the common config parameters. Also, add more parameters like sampleCount
  // from the above reference.
  language: z
    .enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN'])
    .describe('Language of the prompt text.')
    .optional(),
  aspectRatio: z
    .enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
    .describe('Desired aspect ratio of the output image.')
    .optional(),
  negativePrompt: z
    .string()
    .describe(
      'A description of what to discourage in the generated images. ' +
        'For example: "animals" (removes animals), "blurry" ' +
        '(makes the image clearer), "text" (removes text), or ' +
        '"cropped" (removes cropped images).'
    )
    .optional(),
  // Integer in [1, 2147483647]; the upper bound is 2^31 - 1.
  seed: z
    .number()
    .int()
    .min(1)
    .max(2147483647)
    .describe(
      'Controls the randomization of the image generation process. Use the ' +
        'same seed across requests to provide consistency, or change it to ' +
        'introduce variety in the response.'
    )
    .optional(),
  location: z
    .string()
    .describe('Google Cloud region e.g. us-central1.')
    .optional(),
  personGeneration: z
    .enum(['dont_allow', 'allow_adult', 'allow_all'])
    .describe('Control if/how images of people will be generated by the model.')
    .optional(),
  safetySetting: z
    .enum(['block_most', 'block_some', 'block_few', 'block_fewest'])
    .describe('Adds a filter level to safety filtering.')
    .optional(),
  addWatermark: z
    .boolean()
    .describe('Add an invisible watermark to the generated images.')
    .optional(),
  storageUri: z
    .string()
    .describe('Cloud Storage URI to store the generated images.')
    .optional(),
  // The only accepted value is 'upscale'; it is paired with `upscaleConfig`
  // below for upscaling requests.
  mode: z
    .enum(['upscale'])
    .describe('Mode must be set for upscaling requests.')
    .optional(),
  /**
   * Describes the editing intention for the request.
   *
   * See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2 for details.
   */
  editConfig: z
    .object({
      editMode: z
        .enum([
          'inpainting-insert',
          'inpainting-remove',
          'outpainting',
          'product-image',
        ])
        .describe('Editing intention for the request.')
        .optional(),
      // When `maskMode` is present, its `maskType` is required.
      maskMode: z
        .object({
          maskType: z
            .enum(['background', 'foreground', 'semantic'])
            .describe(
              '"background" automatically generates a mask for all ' +
                'regions except the primary subject(s) of the image, ' +
                '"foreground" automatically generates a mask for the primary ' +
                'subjects(s) of the image. "semantic" segments one or more ' +
                'of the segmentation classes using class ID.'
            ),
          // NOTE(review): `.length(5)` requires exactly five class IDs. If the
          // API accepts fewer (e.g. "up to 5"), `.max(5)` may be what was
          // intended. Confirm against the Imagen API reference.
          classes: z
            .array(z.number())
            .describe('List of class IDs for segmentation.')
            .length(5)
            .optional(),
        })
        .describe(
          'Prompts the model to generate a mask instead of you ' +
            'needing to provide one. Consequently, when you provide ' +
            'this parameter you can omit a mask object.'
        )
        // Keep undeclared maskMode keys rather than stripping them.
        .passthrough()
        .optional(),
      // Fraction in [0, 1].
      maskDilation: z
        .number()
        .describe('Dilation percentage of the mask provided.')
        .min(0.0)
        .max(1.0)
        .optional(),
      guidanceScale: z
        .number()
        .describe(
          'Controls how much the model adheres to the text prompt. ' +
            'Large values increase output and prompt alignment, but may ' +
            'compromise image quality. Suggested values are 0-9 ' +
            '(low strength), 10-20 (medium strength), 21+ (high strength).'
        )
        .optional(),
      productPosition: z
        .enum(['reposition', 'fixed'])
        .describe(
          'Defines whether the product should stay fixed or be ' +
            'repositioned.'
        )
        .optional(),
    })
    // Keep undeclared editConfig keys rather than stripping them.
    .passthrough()
    .optional(),
  // When `upscaleConfig` is present, `upscaleFactor` is required.
  upscaleConfig: z
    .object({
      upscaleFactor: z
        .enum(['x2', 'x4'])
        .describe('The factor to upscale the image.'),
    })
    .describe('Configuration for upscaling.')
    .passthrough()
    .optional(),
}).passthrough();
/** Type of the `ImagenConfigSchema` Zod object. */
export type ImagenConfigSchemaType = typeof ImagenConfigSchema;
/** Config object type inferred from `ImagenConfigSchema`. */
export type ImagenConfig = z.infer<ImagenConfigSchemaType>;

// Schema type shared by every reference that commonRef builds.
type ConfigSchemaType = ImagenConfigSchemaType;

/**
 * Builds a `vertexai/`-prefixed Imagen model reference.
 *
 * @param name Bare model name, e.g. `imagen-3.0-generate-002`.
 * @param info Capability info. When omitted, a fresh default is used. It sets
 *   `media: true` and `output: ['media']`, and sets multiturn, tools,
 *   toolChoice and systemRole to `false`.
 * @param configSchema Config schema for the reference; defaults to
 *   `ImagenConfigSchema`.
 * @returns The reference produced by `modelRef`.
 */
function commonRef(
  name: string,
  info?: ModelInfo,
  configSchema: ConfigSchemaType = ImagenConfigSchema
): ModelReference<ConfigSchemaType> {
  const qualifiedName = `vertexai/${name}`;
  const resolvedInfo: ModelInfo = info ?? {
    supports: {
      media: true,
      multiturn: false,
      tools: false,
      toolChoice: false,
      systemRole: false,
      output: ['media'],
    },
  };
  return modelRef({ name: qualifiedName, configSchema, info: resolvedInfo });
}

/**
 * Capability info for Imagen models that are not listed in KNOWN_MODELS.
 * It is deliberately permissive so that unknown future models are not
 * restricted by this plugin.
 */
const GENERIC_MODEL_INFO: ModelInfo = {
  supports: {
    media: true,
    multiturn: true,
    tools: true,
    systemRole: true,
    output: ['media'],
  },
};

/** Fallback reference named `vertexai/imagen`, built from the permissive info above. */
const GENERIC_MODEL = commonRef('imagen', GENERIC_MODEL_INFO);

/**
 * Imagen models with predefined references, keyed by bare model name.
 *
 * Each entry is built by `commonRef` with its default capability info and
 * `ImagenConfigSchema`. `as const` makes the entries readonly.
 */
export const KNOWN_MODELS = {
  'imagen-3.0-generate-002': commonRef('imagen-3.0-generate-002'),
  'imagen-3.0-generate-001': commonRef('imagen-3.0-generate-001'),
  'imagen-3.0-capability-001': commonRef('imagen-3.0-capability-001'),
  'imagen-3.0-fast-generate-001': commonRef('imagen-3.0-fast-generate-001'),
  'imagen-4.0-generate-preview-06-06': commonRef(
    'imagen-4.0-generate-preview-06-06'
  ),
  'imagen-4.0-ultra-generate-preview-06-06': commonRef(
    'imagen-4.0-ultra-generate-preview-06-06'
  ),
} as const;
/** Union of the model names listed in `KNOWN_MODELS`. */
export type KnownModels = keyof typeof KNOWN_MODELS;
/**
 * Imagen model name: any string beginning with `imagen-`.
 *
 * This must agree with the prefix tested by `isImagenModelName`. The type
 * previously read `imagen=${string}`, so the guard narrowed values to a type
 * they never actually matched.
 */
export type ImagenModelName = `imagen-${string}`;

/**
 * Type guard for Imagen model names.
 *
 * @param value Candidate model name. `undefined` is accepted and yields `false`.
 * @returns `true` when `value` starts with `imagen-`.
 */
export function isImagenModelName(value?: string): value is ImagenModelName {
  return !!value?.startsWith('imagen-');
}

/**
 * Returns a model reference for an Imagen model.
 *
 * A name listed in `KNOWN_MODELS` returns that predefined reference with
 * `config` applied via `withConfig`. Any other name gets a new
 * `vertexai/<name>` reference. That reference uses `ImagenConfigSchema` and a
 * shallow copy of the generic (permissive) capability info.
 *
 * @param version Model name or version; normalized by `checkModelName`.
 * @param config Default config attached to the returned reference.
 * @returns A model reference typed with `ImagenConfigSchema`.
 */
export function model(
  version: string,
  config: ImagenConfig = {}
): ModelReference<typeof ImagenConfigSchema> {
  const name = checkModelName(version);
  // Check own properties only. A plain index (`KNOWN_MODELS[name]`) also
  // matches inherited Object.prototype members such as 'constructor' or
  // 'toString'. Those are truthy but have no `withConfig`, so the call throws.
  if (Object.prototype.hasOwnProperty.call(KNOWN_MODELS, name)) {
    return KNOWN_MODELS[name as keyof typeof KNOWN_MODELS].withConfig(config);
  }
  return modelRef({
    name: `vertexai/${name}`,
    config,
    configSchema: ImagenConfigSchema,
    info: {
      ...GENERIC_MODEL.info,
    },
  });
}

/**
 * Produces action metadata for the Imagen entries of a model listing.
 *
 * First, entries whose normalized name (via `modelName`) does not start with
 * `imagen-` are dropped. Each remaining entry is then resolved through
 * `model` so that it carries that reference's name, info and config schema.
 *
 * @param models Model entries, e.g. from a list-models call.
 * @returns One `ActionMetadata` per Imagen model, in input order.
 */
export function listActions(models: Model[]): ActionMetadata[] {
  const imagenModels = models.filter((entry: Model) =>
    isImagenModelName(modelName(entry.name))
  );
  return imagenModels.map((entry: Model) => {
    const { name, info, configSchema } = model(entry.name);
    return modelActionMetadata({ name, info, configSchema });
  });
}

/**
 * Defines one model action per entry in `KNOWN_MODELS`, in key order.
 *
 * @param clientOptions Client options passed to every `defineModel` call.
 * @param pluginOptions Optional plugin options passed to every `defineModel` call.
 * @returns The defined model actions.
 */
export function listKnownModels(
  clientOptions: ClientOptions,
  pluginOptions?: VertexPluginOptions
) {
  const actions: ModelAction[] = [];
  for (const knownName of Object.keys(KNOWN_MODELS)) {
    actions.push(defineModel(knownName, clientOptions, pluginOptions));
  }
  return actions;
}

/**
 * Creates a Genkit model action that serves Imagen predictions.
 *
 * The action's name, capability info and config schema come from `model(name)`.
 * For each request, the handler:
 *  1. derives client options from `clientOptions`, the request's abort signal
 *     and `request.config` via `calculateRequestOptions`;
 *  2. converts the request with `toImagenPredictRequest`;
 *  3. calls `imagenPredict`;
 *  4. converts the result with `fromImagenResponse`.
 *
 * @param name Imagen model name or version.
 * @param clientOptions Base client options for the Vertex AI client.
 * @param pluginOptions Accepted for signature parity; not read by this function.
 * @returns The model action created by `pluginModel`.
 * @throws Error (from the action) when the response has no predictions.
 */
export function defineModel(
  name: string,
  clientOptions: ClientOptions,
  pluginOptions?: VertexPluginOptions
): ModelAction {
  const ref = model(name);

  return pluginModel(
    {
      name: ref.name,
      // Spread the reference's capability info into the action definition.
      ...ref.info,
      configSchema: ref.configSchema,
    },
    async (request, { abortSignal }) => {
      // Forward the per-request abort signal so the underlying call can be cancelled.
      const clientOpt = calculateRequestOptions(
        { ...clientOptions, signal: abortSignal },
        request.config
      );

      const imagenPredictRequest = toImagenPredictRequest(request);

      const response = await imagenPredict(
        extractVersion(ref),
        imagenPredictRequest,
        clientOpt
      );

      // A missing or empty predictions array leaves no images to convert.
      if (!response.predictions || response.predictions.length === 0) {
        throw new Error(
          'Model returned no predictions. Possibly due to content filters.'
        );
      }

      return fromImagenResponse(response, request);
    }
  );
}

/**
 * Internal references exposed so tests can inspect them. Judging by the name,
 * this is not intended as public API.
 */
export const TEST_ONLY = {
  GENERIC_MODEL: GENERIC_MODEL,
  KNOWN_MODELS: KNOWN_MODELS,
};
